import numpy as np
import torch
import torchattacks
import torchvision.transforms as T
from torch.utils.data import DataLoader

import attacks
from datasets import NIPS2017AdversarialCompetition
from constants import BATCH_SIZE, DEVICE
from models import models, sizes, intermediate_layer_names
from utils import de_normalization, imshow


def _adversarial_predictions(attack, data_loader, source_model_name, target_models):
    """Craft adversarial batches and record every target model's predictions on them.

    Args:
        attack: callable ``attack(images, labels)`` returning adversarial images
            at the source model's input resolution.
        data_loader: iterable of ``(images, labels)`` batches.
        source_model_name: key into ``sizes`` for the model being attacked.
        target_models: mapping of model name -> model evaluated on the
            adversarial images.

    Returns:
        ``torch.long`` tensor of shape ``(len(target_models), number_of_samples)``.
        Row ``i`` holds the predicted classes of the ``i``-th model in
        ``target_models`` (dict iteration order), in data-loader order.
    """
    source_size = sizes[source_model_name]
    adv_predict_labels = [[] for _ in target_models]
    seen = 0

    for images, labels in data_loader:
        images = images.to(DEVICE)
        # Only non-299 source resolutions are resized, which suggests the loader
        # yields 299-pixel images — TODO confirm against the dataset.
        if source_size != 299:
            images = T.Resize([source_size, ])(images)
        labels = labels.to(DEVICE)

        adv = attack(images, labels)

        with torch.no_grad():
            for row, (model_name, target_model) in zip(adv_predict_labels, target_models.items()):
                target_size = sizes[model_name]
                if target_size == source_size:
                    # Defensive copy so a target that edits its input in place
                    # cannot alter what later targets see.
                    adv_resized = adv.clone()
                else:
                    # adv has the source model's resolution; match the target's.
                    adv_resized = T.Resize([target_size, ])(adv)
                predicts = torch.max(target_model(adv_resized), dim=1)[1]
                # One host transfer per batch instead of one .item() per image.
                row.extend(predicts.tolist())

        # Count the real batch size: the final batch may be smaller than BATCH_SIZE.
        previous = seen
        seen += images.size(0)
        if seen // 100 > previous // 100:
            print(f'Loop: {seen}')

    return torch.tensor(adv_predict_labels, dtype=torch.long)


def _format_asr_row(source_model_name, asr, ave_asr):
    """Render one markdown table row: source model, ASR per target, average transfer ASR."""
    cells = ''.join(f'| {rate:.4f} ' for rate in asr.tolist())
    return f'| {source_model_name}    {cells}| {ave_asr:.4f}  |'


def main():
    """Measure attack success rates (ASR) of the configured attack across models.

    For each source model in the list below, adversarial examples are crafted
    with the selected attack and classified by every model in ``models``.
    A target's ASR is the fraction of images it classified correctly on clean
    input (``resources/ground_truth/predict_labels.pt``) that it misclassifies
    after the attack. One markdown table row is printed per source model: the
    ASR against each target, then the average ASR over all targets other than
    the source itself.
    """
    # transforms
    transforms = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # dataset
    dataset = NIPS2017AdversarialCompetition(transform=transforms, requires_grad=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    target_models = models
    target_names = list(target_models)
    # Clean predictions and true labels do not depend on the source model, so
    # load them once. NOTE(review): row i of predict_labels is assumed to match
    # the i-th entry of ``models`` — verify against the script that produced it.
    predict_labels = torch.load('resources/ground_truth/predict_labels.pt')
    ground_truth = torch.load('resources/ground_truth/ground_truth.pt')
    # Previously only the source models were switched to eval(); make sure every
    # target runs in inference mode (BatchNorm uses running stats, Dropout off).
    for target_model in target_models.values():
        target_model.eval()

    # for source_model_name in ['resnet152']:
    for source_model_name in ['inception_v3', 'inception_resnet_v2', 'resnet152', 'vgg16']:
        # models
        source_model = models[source_model_name].eval().to(DEVICE)
        # intermediate_layer_name = intermediate_layer_names[source_model_name]
        # attack
        # attack = attacks.IFGSM(source_model, eps=16 / 255, steps=10)
        # attack = attacks.PGD(source_model, eps=16 / 255, steps=10)
        # attack = attacks.DIM(source_model, eps=16 / 255, steps=10, decay=1.0, resize_rate=0.9, diversity_prob=0.5)
        # attack = attacks.TIM(source_model, eps=16 / 255, steps=10, decay=1.0, kernel_name='gaussian', len_kernel=15, nsig=3)
        # attack = attacks.SIM(source_model, eps=16 / 255, steps=10, decay=1.0, m=5)
        # attack = attacks.SINIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0, m=5)
        # attack = attacks.MIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0)
        # attack = attacks.NIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0)
        # attack = attacks.VMIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0, N=20, beta=1.5)
        # attack = attacks.VNIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0, N=20, beta=1.5)
        # attack = attacks.FIA(source_model, eps=16 / 255, steps=10, decay=1.0, drop_probability=0.3, ensemble_number=30,
        #                      intermediate_layer_name=intermediate_layer_name)
        # attack = attacks.FIEPNIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0, epochs=5, drop_probability=0.3,
        #                             ensemble_number=30, intermediate_layer_name=intermediate_layer_name)
        # attack = torchattacks.PGD(source_model, eps=16 / 255, alpha=1.6 / 255, steps=10, random_start=True)
        # attack = attacks.EMIM(source_model, eps=16 / 255, steps=10, decay=1.0, epochs=5)
        attack = attacks.EPNA(source_model, eps=16 / 255, steps=10, decay=0.0, epochs=0)
        # attack = attacks.EPNDIM(source_model, eps=16 / 255, steps=10, decay=1.0, epochs=5)
        # attack = attacks.EPNTIM(source_model, eps=16 / 255, steps=10, decay=1.0, epochs=5, kernel_name='gaussian',
        #                         len_kernel=15, nsig=3)
        # attack = attacks.EPNSIM(source_model, eps=16 / 255, steps=10, decay=1.0, epochs=5, m=5)
        # attack = attacks.VTEPNIFGSM(source_model, eps=16 / 255, steps=10, decay=1.0, N=20, beta=1.5, epochs=5)

        adv_predict_labels = _adversarial_predictions(attack, data_loader, source_model_name, target_models)

        accuracy = torch.zeros((len(target_models),), dtype=torch.float, device=DEVICE)
        asr = torch.zeros((len(target_models),), dtype=torch.float, device=DEVICE)
        for index, adv_predicts in enumerate(adv_predict_labels):
            clean_correct = predict_labels[index] == ground_truth
            accuracy_num = torch.sum(clean_correct.int())
            # Success: correct on clean input, wrong after the perturbation.
            attack_success_num = torch.sum((clean_correct & (adv_predicts != ground_truth)).int())
            accuracy[index] = accuracy_num / len(dataset)
            asr[index] = attack_success_num / accuracy_num

        # Average transfer ASR: exclude the white-box case where target == source.
        transfer = [i for i, name in enumerate(target_names) if name != source_model_name]
        ave_asr = sum(asr[i].item() for i in transfer) / len(transfer) if transfer else float('nan')
        # print(f'accuracy: {accuracy}')
        print(_format_asr_row(source_model_name, asr, ave_asr))


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
